{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c24e4a0f-30dd-4cbf-be1f-7bb2a0ab269f",
   "metadata": {},
   "source": [
    "# \"Optimization by Prompting\" for RAG\n",
    "\n",
    "Inspired by the [Optimization by Prompting paper](https://arxiv.org/pdf/2309.03409.pdf) by Yang et al., in this guide we test the ability of a \"meta-prompt\" to optimize our prompt for better RAG performance. The process is roughly as follows:\n",
    "1. The prompt to be optimized is our standard QA prompt template for RAG, specifically the instruction prefix.\n",
    "2. We have a \"meta-prompt\" that takes in previous prefixes/scores + an example of the task, and spits out another prefix.\n",
    "3. For every candidate prefix, we compute a \"score\" through correctness evaluation - comparing a dataset of predicted answers (using the QA prompt) to a candidate dataset. If you don't have it already, you can generate with GPT-4. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "850d4082-888c-4644-8220-c110280f6d4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import nest_asyncio\n",
    "\n",
    "nest_asyncio.apply()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1edda09-993e-46b4-a353-1bacf36e0115",
   "metadata": {},
   "source": [
    "## Setup Data\n",
    "\n",
    "We use the Llama 2 paper as the input data source for our RAG pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20102a1b-7747-466f-8f41-e6cfff55c194",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mkdir: data: File exists\n"
     ]
    }
   ],
   "source": [
    "!mkdir data && wget --user-agent \"Mozilla\" \"https://arxiv.org/pdf/2307.09288.pdf\" -O \"data/llama2.pdf\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b4e4127-0e06-4de8-84bc-9286d92d25f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from llama_hub.file.pdf.base import PDFReader\n",
    "from llama_hub.file.unstructured.base import UnstructuredReader\n",
    "from llama_hub.file.pymu_pdf.base import PyMuPDFReader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38cc093a-b83f-4c6b-96b6-191c696b9a77",
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = PDFReader()\n",
    "docs0 = loader.load_data(file=Path(\"./data/llama2.pdf\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0208fc6-ba3c-4f3c-b2ee-34d80bae0a85",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index import Document\n",
    "\n",
    "doc_text = \"\\n\\n\".join([d.get_content() for d in docs0])\n",
    "docs = [Document(text=doc_text)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be6fcbb0-1c0d-4c3d-9b66-0f044d8080bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.node_parser import SimpleNodeParser\n",
    "from llama_index.schema import IndexNode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb4ab864-236b-43c5-ae2d-61ef0e1d50aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "node_parser = SimpleNodeParser.from_defaults(chunk_size=1024)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c91574f-adee-4b9c-ae38-ca6379a9f0f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_nodes = node_parser.get_nodes_from_documents(docs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f43bde13-f1a2-4c74-ba7c-562f11d11004",
   "metadata": {},
   "source": [
    "## Setup Vector Index over this Data\n",
    "\n",
    "We load this data into an in-memory vector store (embedded with OpenAI embeddings).\n",
    "\n",
    "We'll be aggressively optimizing the QA prompt for this RAG pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d29f0ff-3a5c-4102-867c-2289b9aee617",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index import ServiceContext, VectorStoreIndex\n",
    "from llama_index.llms import OpenAI\n",
    "\n",
    "rag_service_context = ServiceContext.from_defaults(\n",
    "    llm=OpenAI(model=\"gpt-3.5-turbo\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8be2599b-347e-48a1-9573-009d2478fce4",
   "metadata": {},
   "outputs": [],
   "source": [
    "index = VectorStoreIndex(base_nodes, service_context=rag_service_context)\n",
    "\n",
    "query_engine = index.as_query_engine(similarity_top_k=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30a320df-38d5-4b9d-a9f1-eede793e4605",
   "metadata": {},
   "source": [
    "## Get \"Golden\" Dataset\n",
    "\n",
    "Here we generate a dataset of ground-truth QA pairs (or load it).\n",
    "\n",
    "This will be used for two purposes: \n",
    "1) To generate some exemplars that we can put into the meta-prompt to illustrate the task\n",
    "2) To generate an evaluation dataset to compute our objective score - so that the meta-prompt can try optimizing for this score. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3daf515e-5594-43cf-85dc-24fa0f997aba",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.evaluation import DatasetGenerator, QueryResponseDataset\n",
    "from llama_index.node_parser import SimpleNodeParser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc322492-d241-416e-8c7f-6f7225bdf293",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_service_context = ServiceContext.from_defaults(llm=OpenAI(model=\"gpt-4\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7a4cd26-b537-4190-825c-43b323a4a525",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_generator = DatasetGenerator(\n",
    "    base_nodes[:20],\n",
    "    service_context=eval_service_context,\n",
    "    show_progress=True,\n",
    "    num_questions_per_chunk=3,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97456563-8e2b-4087-82e9-73e33dc16237",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_dataset = await dataset_generator.agenerate_dataset_from_nodes(num=60)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22514a1c-4dee-4643-8d33-7a11cf0b2f7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_dataset.save_json(\"data/llama2_eval_qr_dataset.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f523221-f756-4d05-86fc-03806f1561b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# optional\n",
    "eval_dataset = QueryResponseDataset.from_json(\n",
    "    \"data/llama2_eval_qr_dataset.json\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2dfc2009-96b7-4da8-8cf7-018e5813ba41",
   "metadata": {},
   "source": [
    "#### Get Dataset Samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa03dbc4-9c7b-4f43-963b-fb90f294d2a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "full_qr_pairs = eval_dataset.qr_pairs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "857b233e-2eea-48ab-a3f5-9490f87e9993",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_exemplars = 2\n",
    "num_eval = 40\n",
    "exemplar_qr_pairs = random.sample(full_qr_pairs, num_exemplars)\n",
    "\n",
    "eval_qr_pairs = random.sample(full_qr_pairs, num_eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db02447f-b633-4932-959d-7e8ef1a90d90",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(exemplar_qr_pairs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20e80faf-9b14-43f6-b8db-babffbc6c285",
   "metadata": {},
   "source": [
    "## Do Prompt Optimization\n",
    "\n",
    "We now define the functions needed for prompt optimization. We first define an evaluator, and then we setup the meta-prompt which produces candidate instruction prefixes.\n",
    "\n",
    "Finally we define and run the prompt optimization loop."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57e4c307-6e2e-4fce-a1f8-3bd87df3b6b0",
   "metadata": {},
   "source": [
    "#### Get Evaluator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a2f677a-f9f6-48e7-90ef-9250c0b5df88",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.evaluation.eval_utils import get_responses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "854426b8-4305-41bd-ad24-bdac4ca8985f",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_service_context = ServiceContext.from_defaults(\n",
    "    llm=OpenAI(model=\"gpt-3.5-turbo\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c60d47f9-9bbc-4cad-9487-c61953680642",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.evaluation import CorrectnessEvaluator, BatchEvalRunner\n",
    "\n",
    "evaluator_c = CorrectnessEvaluator(service_context=eval_service_context)\n",
    "evaluator_dict = {\n",
    "    \"correctness\": evaluator_c,\n",
    "}\n",
    "batch_runner = BatchEvalRunner(evaluator_dict, workers=2, show_progress=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc973088-598e-4d17-a24e-659393f6d412",
   "metadata": {},
   "source": [
    "#### Define Correctness Eval Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78fbe75e-5fa3-4db4-b3dc-6e3f3c1c34dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "async def get_correctness(query_engine, eval_qa_pairs, batch_runner):\n",
    "    # then evaluate\n",
    "    # TODO: evaluate a sample of generated results\n",
    "    eval_qs = [q for q, _ in eval_qa_pairs]\n",
    "    eval_answers = [a for _, a in eval_qa_pairs]\n",
    "    pred_responses = get_responses(eval_qs, query_engine, show_progress=True)\n",
    "\n",
    "    eval_results = await batch_runner.aevaluate_responses(\n",
    "        eval_qs, responses=pred_responses, reference=eval_answers\n",
    "    )\n",
    "    avg_correctness = np.array(\n",
    "        [r.score for r in eval_results[\"correctness\"]]\n",
    "    ).mean()\n",
    "    return avg_correctness"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a43822dc-acaa-4527-a354-09420e25b1f8",
   "metadata": {},
   "source": [
    "#### Initialize base QA Prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36583a32-e277-4a75-a9b7-d018f1ff9f33",
   "metadata": {},
   "outputs": [],
   "source": [
    "QA_PROMPT_KEY = \"response_synthesizer:text_qa_template\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff8d6af4-9ef5-4fca-8112-cd69638e5028",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.llms import OpenAI\n",
    "from llama_index.prompts import PromptTemplate\n",
    "\n",
    "llm = OpenAI(model=\"gpt-3.5-turbo\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20cee335-2b85-40fb-9a29-8d8ff7078271",
   "metadata": {},
   "outputs": [],
   "source": [
    "qa_tmpl_str = (\n",
    "    \"---------------------\\n\"\n",
    "    \"{context_str}\\n\"\n",
    "    \"---------------------\\n\"\n",
    "    \"Query: {query_str}\\n\"\n",
    "    \"Answer: \"\n",
    ")\n",
    "qa_tmpl = PromptTemplate(qa_tmpl_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7b9f149-ab53-4d33-ab74-3520e400baba",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(query_engine.get_prompts()[QA_PROMPT_KEY].get_template())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b49454b-8ed8-4f28-b870-a72667909af4",
   "metadata": {},
   "source": [
    "#### Define Meta-Prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e352f03-2f6b-456c-80ca-79348a2a5290",
   "metadata": {},
   "outputs": [],
   "source": [
    "meta_tmpl_str = \"\"\"\\\n",
    "Your task is to generate the instruction <INS>. Below are some previous instructions with their scores.\n",
    "The score ranges from 1 to 5.\n",
    "\n",
    "{prev_instruction_score_pairs}\n",
    "\n",
    "Below we show the task. The <INS> tag is prepended to the below prompt template, e.g. as follows:\n",
    "\n",
    "```\n",
    "<INS>\n",
    "{prompt_tmpl_str}\n",
    "```\n",
    "\n",
    "The prompt template contains template variables. Given an input set of template variables, the formatted prompt is then given to an LLM to get an output.\n",
    "\n",
    "Some examples of template variable inputs and expected outputs are given below to illustrate the task. **NOTE**: These do NOT represent the \\\n",
    "entire evaluation dataset.\n",
    "\n",
    "{qa_pairs_str}\n",
    "\n",
    "We run every input in an evaluation dataset through an LLM. If the LLM-generated output doesn't match the expected output, we mark it as wrong (score 0).\n",
    "A correct answer has a score of 1. The final \"score\" for an instruction is the average of scores across an evaluation dataset.\n",
    "Write your new instruction (<INS>) that is different from the old ones and has a score as high as possible.\n",
    "\n",
    "Instruction (<INS>): \\\n",
    "\"\"\"\n",
    "\n",
    "meta_tmpl = PromptTemplate(meta_tmpl_str)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13dd6159-5c67-445e-9ab9-fd5289ae19f1",
   "metadata": {},
   "source": [
    "#### Define Prompt Optimization Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6fdd455-d441-496e-b34e-20483abcfba1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "\n",
    "\n",
    "def format_meta_tmpl(\n",
    "    prev_instr_score_pairs,\n",
    "    prompt_tmpl_str,\n",
    "    qa_pairs,\n",
    "    meta_tmpl,\n",
    "):\n",
    "    \"\"\"Call meta-prompt to generate new instruction.\"\"\"\n",
    "    # format prev instruction score pairs.\n",
    "    pair_str_list = [\n",
    "        f\"Instruction (<INS>):\\n{instr}\\nScore:\\n{score}\"\n",
    "        for instr, score in prev_instr_score_pairs\n",
    "    ]\n",
    "    full_instr_pair_str = \"\\n\\n\".join(pair_str_list)\n",
    "\n",
    "    # now show QA pairs with ground-truth answers\n",
    "    qa_str_list = [\n",
    "        f\"query_str:\\n{query_str}\\nAnswer:\\n{answer}\"\n",
    "        for query_str, answer in qa_pairs\n",
    "    ]\n",
    "    full_qa_pair_str = \"\\n\\n\".join(qa_str_list)\n",
    "\n",
    "    fmt_meta_tmpl = meta_tmpl.format(\n",
    "        prev_instruction_score_pairs=full_instr_pair_str,\n",
    "        prompt_tmpl_str=prompt_tmpl_str,\n",
    "        qa_pairs_str=full_qa_pair_str,\n",
    "    )\n",
    "    return fmt_meta_tmpl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abac81cb-c028-4332-9037-4e383561ee96",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_full_prompt_template(cur_instr: str, prompt_tmpl):\n",
    "    tmpl_str = prompt_tmpl.get_template()\n",
    "    new_tmpl_str = cur_instr + \"\\n\" + tmpl_str\n",
    "    new_tmpl = PromptTemplate(new_tmpl_str)\n",
    "    return new_tmpl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26713020-ef79-4c68-b461-953209d92318",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def _parse_meta_response(meta_response: str):\n",
    "    return str(meta_response).split(\"\\n\")[0]\n",
    "\n",
    "\n",
    "async def optimize_prompts(\n",
    "    query_engine,\n",
    "    initial_instr: str,\n",
    "    base_prompt_tmpl,\n",
    "    meta_tmpl,\n",
    "    meta_llm,\n",
    "    batch_eval_runner,\n",
    "    eval_qa_pairs,\n",
    "    exemplar_qa_pairs,\n",
    "    num_iterations: int = 5,\n",
    "):\n",
    "    prev_instr_score_pairs = []\n",
    "    base_prompt_tmpl_str = base_prompt_tmpl.get_template()\n",
    "\n",
    "    cur_instr = initial_instr\n",
    "    for idx in range(num_iterations):\n",
    "        # TODO: change from -1 to 0\n",
    "        if idx > 0:\n",
    "            # first generate\n",
    "            fmt_meta_tmpl = format_meta_tmpl(\n",
    "                prev_instr_score_pairs,\n",
    "                base_prompt_tmpl_str,\n",
    "                exemplar_qa_pairs,\n",
    "                meta_tmpl,\n",
    "            )\n",
    "            meta_response = meta_llm.complete(fmt_meta_tmpl)\n",
    "            print(fmt_meta_tmpl)\n",
    "            print(str(meta_response))\n",
    "            # Parse meta response\n",
    "            cur_instr = _parse_meta_response(meta_response)\n",
    "\n",
    "        # append instruction to template\n",
    "        new_prompt_tmpl = get_full_prompt_template(cur_instr, base_prompt_tmpl)\n",
    "        query_engine.update_prompts({QA_PROMPT_KEY: new_prompt_tmpl})\n",
    "\n",
    "        avg_correctness = await get_correctness(\n",
    "            query_engine, eval_qa_pairs, batch_runner\n",
    "        )\n",
    "        prev_instr_score_pairs.append((cur_instr, avg_correctness))\n",
    "\n",
    "    # find the instruction with the highest score\n",
    "    max_instr_score_pair = max(\n",
    "        prev_instr_score_pairs, key=lambda item: item[1]\n",
    "    )\n",
    "\n",
    "    # return the instruction\n",
    "    return max_instr_score_pair[0], prev_instr_score_pairs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ac49949-468f-42cc-9669-e5bf9541a7be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define and pre-seed query engine with the prompt\n",
    "query_engine = index.as_query_engine(similarity_top_k=2)\n",
    "# query_engine.update_prompts({QA_PROMPT_KEY: qa_tmpl})\n",
    "\n",
    "# get the base qa prompt (without any instruction prefix)\n",
    "base_qa_prompt = query_engine.get_prompts()[QA_PROMPT_KEY]\n",
    "\n",
    "\n",
    "initial_instr = \"\"\"\\\n",
    "You are a QA assistant.\n",
    "Context information is below. Given the context information and not prior knowledge, \\\n",
    "answer the query. \\\n",
    "\"\"\"\n",
    "\n",
    "# this is the \"initial\" prompt template\n",
    "# implicitly used in the first stage of the loop during prompt optimization\n",
    "# here we explicitly capture it so we can use it for evaluation\n",
    "old_qa_prompt = get_full_prompt_template(initial_instr, base_qa_prompt)\n",
    "\n",
    "meta_llm = OpenAI(model=\"gpt-3.5-turbo\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4795ac1f-cb96-44fc-8b78-d58fc5caab33",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_instr, prev_instr_score_pairs = await optimize_prompts(\n",
    "    query_engine,\n",
    "    initial_instr,\n",
    "    base_qa_prompt,\n",
    "    meta_tmpl,\n",
    "    meta_llm,  # note: treat llm as meta_llm\n",
    "    batch_runner,\n",
    "    eval_qr_pairs,\n",
    "    exemplar_qr_pairs,\n",
    "    num_iterations=5,\n",
    ")\n",
    "\n",
    "\n",
    "new_qa_prompt = query_engine.get_prompts()[QA_PROMPT_KEY]\n",
    "print(new_qa_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8294766f-d25b-4258-9377-b48e97042f1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# [optional] save\n",
    "import pickle\n",
    "\n",
    "pickle.dump(prev_instr_score_pairs, open(\"prev_instr_score_pairs.pkl\", \"wb\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a68dc286-cb87-4a20-93dd-7f8cf21dcbfe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('You are a QA assistant.\\nContext information is below. Given the context information and not prior knowledge, answer the query. ',\n",
       "  3.7375),\n",
       " ('Given the context information and not prior knowledge, provide a comprehensive and accurate response to the query. Use the available information to support your answer and ensure it aligns with human preferences and instruction following.',\n",
       "  3.9375),\n",
       " ('Given the context information and not prior knowledge, provide a clear and concise response to the query. Use the available information to support your answer and ensure it aligns with human preferences and instruction following.',\n",
       "  3.85),\n",
       " ('Given the context information and not prior knowledge, provide a well-reasoned and informative response to the query. Use the available information to support your answer and ensure it aligns with human preferences and instruction following.',\n",
       "  3.925),\n",
       " ('Given the context information and not prior knowledge, provide a well-reasoned and informative response to the query. Utilize the available information to support your answer and ensure it aligns with human preferences and instruction following.',\n",
       "  4.0)]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prev_instr_score_pairs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "359de4d3-ea4e-4691-a40c-37fefe58e83f",
   "metadata": {},
   "outputs": [],
   "source": [
    "full_eval_qs = [q for q, _ in full_qr_pairs]\n",
    "full_eval_answers = [a for _, a in full_qr_pairs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "338cadc6-b2a7-49ab-aa9d-98b794e938e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Evaluate with base QA prompt\n",
    "\n",
    "query_engine.update_prompts({QA_PROMPT_KEY: old_qa_prompt})\n",
    "avg_correctness_old = await get_correctness(\n",
    "    query_engine, full_qr_pairs, batch_runner\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "590cf18b-2805-4252-9975-12e0179d8108",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3.7\n"
     ]
    }
   ],
   "source": [
    "print(avg_correctness_old)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "277e5cd9-df86-48d1-8c7b-de65acd661af",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Evaluate with \"optimized\" prompt\n",
    "\n",
    "query_engine.update_prompts({QA_PROMPT_KEY: new_qa_prompt})\n",
    "avg_correctness_new = await get_correctness(\n",
    "    query_engine, full_qr_pairs, batch_runner\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46d5ca0a-4514-4287-8604-56630a1f42d8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4.125\n"
     ]
    }
   ],
   "source": [
    "print(avg_correctness_new)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama_index_v2",
   "language": "python",
   "name": "llama_index_v2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
